from matplotlib import pyplot as plt


def draw(draw_data, *, save_path="train.png", show=True):
    """Plot train/val accuracy and loss curves side by side and save the figure.

    The left panel shows accuracy and the right panel shows loss, both against
    ``draw_data["epoch"]``.

    Args:
        draw_data: Mapping with the keys "epoch", "accTrain", "accVal",
            "lossTrain" and "lossVal". Each metric sequence is plotted against
            "epoch", so (presumably) all have one value per epoch.
        save_path: Destination file for the figure. Defaults to "train.png",
            the path previously hard-coded here.
        show: If True (the default, matching the previous behavior), display
            the figure with ``plt.show()`` after saving. If False, the figure
            is closed instead, which is useful for headless/batch runs.

    Side effects:
        Writes the image to ``save_path`` at 300 dpi and prints a confirmation
        message naming that path.
    """
    epochs = draw_data["epoch"]
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    acc_ax.plot(epochs, draw_data["accTrain"], label="train acc")
    acc_ax.plot(epochs, draw_data["accVal"], label="val acc")
    acc_ax.set_xlabel("epoch")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.legend()

    loss_ax.plot(epochs, draw_data["lossTrain"], label="train loss")
    loss_ax.plot(epochs, draw_data["lossVal"], label="val loss")
    loss_ax.set_xlabel("epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()

    fig.savefig(save_path, dpi=300, bbox_inches="tight")
    # Report the file that was actually written; the old message named
    # "训练结果.png", which never matched the saved file.
    print(f"保存成功：{save_path}")

    if show:
        plt.show()
    else:
        # Not displayed: release the figure so repeated calls don't leak
        # open figures in pyplot's figure registry.
        plt.close(fig)

